{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import logging\n",
    "import argparse\n",
    "import subprocess\n",
    "from com.yahoo.ml.tf import TFCluster\n",
    "import mnist_dist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "reload(logging)\n",
    "logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=logging.INFO, datefmt='%I:%M:%S')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument(\"-e\", \"--epochs\", help=\"number of epochs\", type=int, default=1)\n",
    "parser.add_argument(\"-i\", \"--images\", help=\"HDFS path to MNIST images in parallelized format\")\n",
    "parser.add_argument(\"-l\", \"--labels\", help=\"HDFS path to MNIST labels in parallelized format\")\n",
    "parser.add_argument(\"-f\", \"--format\", help=\"example format\", choices=[\"csv\",\"pickle\",\"tfr\"], default=\"csv\")\n",
    "parser.add_argument(\"-m\", \"--model\", help=\"HDFS path to save/load model during train/test\", default=\"mnist_model\")\n",
    "parser.add_argument(\"-r\", \"--readers\", help=\"number of reader/enqueue threads\", type=int, default=1)\n",
    "parser.add_argument(\"-s\", \"--steps\", help=\"maximum number of steps\", type=int, default=500)\n",
    "parser.add_argument(\"-X\", \"--mode\", help=\"train|inference\", default=\"train\")\n",
    "parser.add_argument(\"-c\", \"--rdma\", help=\"use rdma connection\", default=False)\n",
    "num_executors = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#remove existing models if any\n",
    "subprocess.call([\"rm\", \"-rf\", \"mnist_model\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total 213808\n",
      "-rw-r--r--  1 afeng  staff         0 Feb  8 14:52 _SUCCESS\n",
      "-rw-r--r--  1 afeng  staff   9338236 Feb  8 14:52 part-00000\n",
      "-rw-r--r--  1 afeng  staff  11231804 Feb  8 14:52 part-00001\n",
      "-rw-r--r--  1 afeng  staff  11214784 Feb  8 14:52 part-00002\n",
      "-rw-r--r--  1 afeng  staff  11226100 Feb  8 14:52 part-00003\n",
      "-rw-r--r--  1 afeng  staff  11212767 Feb  8 14:52 part-00004\n",
      "-rw-r--r--  1 afeng  staff  11173834 Feb  8 14:52 part-00005\n",
      "-rw-r--r--  1 afeng  staff  11214285 Feb  8 14:52 part-00006\n",
      "-rw-r--r--  1 afeng  staff  11201024 Feb  8 14:52 part-00007\n",
      "-rw-r--r--  1 afeng  staff  11194141 Feb  8 14:52 part-00008\n",
      "-rw-r--r--  1 afeng  staff  10449019 Feb  8 14:52 part-00009\n",
      "\n"
     ]
    }
   ],
   "source": [
    "#verify training images\n",
    "train_images_files = \"csv/train/images\"\n",
    "print(subprocess.check_output([\"ls\", \"-l\", train_images_files]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total 4688\n",
      "-rw-r--r--  1 afeng  staff       0 Feb  8 14:52 _SUCCESS\n",
      "-rw-r--r--  1 afeng  staff  204800 Feb  8 14:52 part-00000\n",
      "-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00001\n",
      "-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00002\n",
      "-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00003\n",
      "-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00004\n",
      "-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00005\n",
      "-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00006\n",
      "-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00007\n",
      "-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00008\n",
      "-rw-r--r--  1 afeng  staff  229120 Feb  8 14:52 part-00009\n",
      "\n"
     ]
    }
   ],
   "source": [
    "#verify training labels\n",
    "train_labels_files = \"csv/train/labels\"\n",
    "print(subprocess.check_output([\"ls\", \"-l\", train_labels_files]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "04:44:41 INFO:Reserving TFSparkNodes w/ TensorBoard\n",
      "04:44:52 INFO:TensorBoard running at: http://notforever-lm:56451\n"
     ]
    }
   ],
   "source": [
    "#reserver a cluster for training\n",
    "cluster = TFCluster.reserve(sc, num_executors, 1, True, TFCluster.InputMode.SPARK)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "04:45:04 INFO:Starting TensorFlow\n",
      "04:45:10 INFO:Feeding training data\n"
     ]
    }
   ],
   "source": [
    "#Check out tensorboard at http://localhost:<tb_port> per above during the execution of this step.\n",
    "#It may wait a little for TensorFlow TAG to be loaded.\n",
    "#\n",
    "args = parser.parse_args(['--mode', 'train', \n",
    "                          '--images', train_images_files, \n",
    "                          '--labels', train_labels_files])\n",
    "cluster.start(mnist_dist.map_fun, args)\n",
    "#Feed data via Spark RDD\n",
    "images = sc.textFile(args.images).map(lambda ln: [int(x) for x in ln.split(',')])\n",
    "labels = sc.textFile(args.labels).map(lambda ln: [float(x) for x in ln.split(',')])\n",
    "dataRDD = images.zip(labels)\n",
    "cluster.train(dataRDD, args.epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "04:46:56 INFO:Stopping TensorFlow nodes\n"
     ]
    }
   ],
   "source": [
    "cluster.shutdown()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total 8672\n",
      "-rw-r--r--  1 afeng  staff     263 Feb 12 16:45 checkpoint\n",
      "-rw-r--r--  1 afeng  staff  113289 Feb 12 16:45 graph.pbtxt\n",
      "-rw-r--r--  1 afeng  staff  814164 Feb 12 16:45 model.ckpt-0.data-00000-of-00001\n",
      "-rw-r--r--  1 afeng  staff     372 Feb 12 16:45 model.ckpt-0.index\n",
      "-rw-r--r--  1 afeng  staff   43896 Feb 12 16:45 model.ckpt-0.meta\n",
      "-rw-r--r--  1 afeng  staff  814164 Feb 12 16:45 model.ckpt-115.data-00000-of-00001\n",
      "-rw-r--r--  1 afeng  staff     372 Feb 12 16:45 model.ckpt-115.index\n",
      "-rw-r--r--  1 afeng  staff   43896 Feb 12 16:45 model.ckpt-115.meta\n",
      "-rw-r--r--  1 afeng  staff  814164 Feb 12 16:45 model.ckpt-237.data-00000-of-00001\n",
      "-rw-r--r--  1 afeng  staff     372 Feb 12 16:45 model.ckpt-237.index\n",
      "-rw-r--r--  1 afeng  staff   43896 Feb 12 16:45 model.ckpt-237.meta\n",
      "-rw-r--r--  1 afeng  staff  814164 Feb 12 16:45 model.ckpt-360.data-00000-of-00001\n",
      "-rw-r--r--  1 afeng  staff     372 Feb 12 16:45 model.ckpt-360.index\n",
      "-rw-r--r--  1 afeng  staff   43896 Feb 12 16:45 model.ckpt-360.meta\n",
      "-rw-r--r--  1 afeng  staff  814164 Feb 12 16:45 model.ckpt-483.data-00000-of-00001\n",
      "-rw-r--r--  1 afeng  staff     372 Feb 12 16:45 model.ckpt-483.index\n",
      "-rw-r--r--  1 afeng  staff   43896 Feb 12 16:45 model.ckpt-483.meta\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(subprocess.check_output([\"ls\", \"-l\", \"mnist_model\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total 35720\n",
      "-rw-r--r--  1 afeng  staff        0 Feb  8 14:53 _SUCCESS\n",
      "-rw-r--r--  1 afeng  staff  1810248 Feb  8 14:53 part-00000\n",
      "-rw-r--r--  1 afeng  staff  1806102 Feb  8 14:53 part-00001\n",
      "-rw-r--r--  1 afeng  staff  1811128 Feb  8 14:53 part-00002\n",
      "-rw-r--r--  1 afeng  staff  1812952 Feb  8 14:53 part-00003\n",
      "-rw-r--r--  1 afeng  staff  1810946 Feb  8 14:53 part-00004\n",
      "-rw-r--r--  1 afeng  staff  1835497 Feb  8 14:53 part-00005\n",
      "-rw-r--r--  1 afeng  staff  1845261 Feb  8 14:53 part-00006\n",
      "-rw-r--r--  1 afeng  staff  1850655 Feb  8 14:53 part-00007\n",
      "-rw-r--r--  1 afeng  staff  1852712 Feb  8 14:53 part-00008\n",
      "-rw-r--r--  1 afeng  staff  1833942 Feb  8 14:53 part-00009\n",
      "\n"
     ]
    }
   ],
   "source": [
    "#verify test images\n",
    "test_images_files = \"csv/test/images\"\n",
    "print(subprocess.check_output([\"ls\", \"-l\", test_images_files]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total 800\n",
      "-rw-r--r--  1 afeng  staff      0 Feb  8 14:53 _SUCCESS\n",
      "-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00000\n",
      "-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00001\n",
      "-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00002\n",
      "-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00003\n",
      "-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00004\n",
      "-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00005\n",
      "-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00006\n",
      "-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00007\n",
      "-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00008\n",
      "-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00009\n",
      "\n"
     ]
    }
   ],
   "source": [
    "#verify test labels\n",
    "test_labels_files = \"csv/test/labels\"\n",
    "print(subprocess.check_output([\"ls\", \"-l\", test_labels_files]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "04:48:56 INFO:Reserving TFSparkNodes w/ TensorBoard\n",
      "04:49:06 INFO:TensorBoard running at: http://notforever-lm:56990\n"
     ]
    }
   ],
   "source": [
    "#reserver cluster for inference\n",
    "cluster = TFCluster.reserve(sc, num_executors, 1, True, TFCluster.InputMode.SPARK)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "04:50:00 INFO:Starting TensorFlow\n",
      "04:50:06 INFO:Feeding inference data\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "['2017-02-12T16:50:08.179387 Label: 7, Prediction: 7',\n",
       " '2017-02-12T16:50:08.179501 Label: 2, Prediction: 2',\n",
       " '2017-02-12T16:50:08.179547 Label: 1, Prediction: 1',\n",
       " '2017-02-12T16:50:08.179587 Label: 0, Prediction: 0',\n",
       " '2017-02-12T16:50:08.179626 Label: 4, Prediction: 4',\n",
       " '2017-02-12T16:50:08.179665 Label: 1, Prediction: 1',\n",
       " '2017-02-12T16:50:08.179704 Label: 4, Prediction: 4',\n",
       " '2017-02-12T16:50:08.179742 Label: 9, Prediction: 9',\n",
       " '2017-02-12T16:50:08.179780 Label: 5, Prediction: 6',\n",
       " '2017-02-12T16:50:08.179818 Label: 9, Prediction: 9',\n",
       " '2017-02-12T16:50:08.179856 Label: 0, Prediction: 0',\n",
       " '2017-02-12T16:50:08.179895 Label: 6, Prediction: 6',\n",
       " '2017-02-12T16:50:08.179933 Label: 9, Prediction: 9',\n",
       " '2017-02-12T16:50:08.179971 Label: 0, Prediction: 0',\n",
       " '2017-02-12T16:50:08.180008 Label: 1, Prediction: 1',\n",
       " '2017-02-12T16:50:08.180046 Label: 5, Prediction: 5',\n",
       " '2017-02-12T16:50:08.180084 Label: 9, Prediction: 9',\n",
       " '2017-02-12T16:50:08.180122 Label: 7, Prediction: 7',\n",
       " '2017-02-12T16:50:08.180160 Label: 3, Prediction: 3',\n",
       " '2017-02-12T16:50:08.180198 Label: 4, Prediction: 4']"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Check out tensorboard at http://localhost:<tb_port> per above\n",
    "#\n",
    "args = parser.parse_args(['--mode', 'inference', \n",
    "                          '--images', test_images_files, \n",
    "                          '--labels', test_labels_files])\n",
    "cluster.start(mnist_dist.map_fun, args)\n",
    "#prepare data as Spark RDD\n",
    "images = sc.textFile(args.images).map(lambda ln: [int(x) for x in ln.split(',')])\n",
    "labels = sc.textFile(args.labels).map(lambda ln: [float(x) for x in ln.split(',')])\n",
    "dataRDD = images.zip(labels)\n",
    "#feed data for inference\n",
    "prediction_results = cluster.inference(dataRDD)\n",
    "prediction_results.take(20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "09:22:36 INFO:Stopping TensorFlow nodes\n"
     ]
    }
   ],
   "source": [
    "cluster.shutdown()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
